{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "2tyC5_SKKr5A",
        "outputId": "08ca6bdd-5941-4854-f201-da233cef6714"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting d2l\n",
            "  Downloading d2l-0.17.3-py3-none-any.whl (82 kB)\n",
            "\u001b[?25l\r\u001b[K     |████                            | 10 kB 25.1 MB/s eta 0:00:01\r\u001b[K     |████████                        | 20 kB 8.1 MB/s eta 0:00:01\r\u001b[K     |████████████                    | 30 kB 5.4 MB/s eta 0:00:01\r\u001b[K     |████████████████                | 40 kB 5.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████▉            | 51 kB 2.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▉        | 61 kB 2.7 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▉    | 71 kB 2.9 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▉| 81 kB 3.2 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 82 kB 610 kB/s \n",
            "\u001b[?25hRequirement already satisfied: jupyter==1.0.0 in /usr/local/lib/python3.7/dist-packages (from d2l) (1.0.0)\n",
            "Collecting matplotlib==3.3.3\n",
            "  Downloading matplotlib-3.3.3-cp37-cp37m-manylinux1_x86_64.whl (11.6 MB)\n",
            "\u001b[K     |████████████████████████████████| 11.6 MB 10.5 MB/s \n",
            "\u001b[?25hCollecting requests==2.25.1\n",
            "  Downloading requests-2.25.1-py2.py3-none-any.whl (61 kB)\n",
            "\u001b[K     |████████████████████████████████| 61 kB 9.4 MB/s \n",
            "\u001b[?25hCollecting pandas==1.2.2\n",
            "  Downloading pandas-1.2.2-cp37-cp37m-manylinux1_x86_64.whl (9.9 MB)\n",
            "\u001b[K     |████████████████████████████████| 9.9 MB 22.1 MB/s \n",
            "\u001b[?25hCollecting numpy==1.18.5\n",
            "  Downloading numpy-1.18.5-cp37-cp37m-manylinux1_x86_64.whl (20.1 MB)\n",
            "\u001b[K     |████████████████████████████████| 20.1 MB 14.6 MB/s \n",
            "\u001b[?25hRequirement already satisfied: qtconsole in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l) (5.2.2)\n",
            "Requirement already satisfied: ipykernel in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l) (4.10.1)\n",
            "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l) (5.2.0)\n",
            "Requirement already satisfied: nbconvert in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l) (5.6.1)\n",
            "Requirement already satisfied: notebook in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l) (5.3.1)\n",
            "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.7/dist-packages (from jupyter==1.0.0->d2l) (7.6.5)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l) (1.3.2)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l) (7.1.2)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l) (3.0.7)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l) (0.11.0)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.3->d2l) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas==1.2.2->d2l) (2018.9)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l) (2021.10.8)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l) (2.10)\n",
            "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l) (3.0.4)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests==2.25.1->d2l) (1.24.3)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.3.3->d2l) (1.15.0)\n",
            "Requirement already satisfied: jupyter-client in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l) (5.3.5)\n",
            "Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l) (5.5.0)\n",
            "Requirement already satisfied: tornado>=4.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l) (5.1.1)\n",
            "Requirement already satisfied: traitlets>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from ipykernel->jupyter==1.0.0->d2l) (5.1.1)\n",
            "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (57.4.0)\n",
            "Requirement already satisfied: simplegeneric>0.8 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.8.1)\n",
            "Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (4.4.2)\n",
            "Requirement already satisfied: pexpect in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (4.8.0)\n",
            "Requirement already satisfied: pickleshare in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.7.5)\n",
            "Requirement already satisfied: pygments in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (2.6.1)\n",
            "Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.7/dist-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (1.0.18)\n",
            "Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.2.5)\n",
            "Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l) (0.2.0)\n",
            "Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l) (5.1.3)\n",
            "Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l) (3.5.2)\n",
            "Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ipywidgets->jupyter==1.0.0->d2l) (1.0.2)\n",
            "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.7/dist-packages (from nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (4.9.1)\n",
            "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.7/dist-packages (from nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (4.3.3)\n",
            "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (0.18.1)\n",
            "Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (5.4.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (3.10.0.2)\n",
            "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (4.11.0)\n",
            "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (21.4.0)\n",
            "Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from importlib-resources>=1.4.0->jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (3.7.0)\n",
            "Requirement already satisfied: terminado>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter==1.0.0->d2l) (0.13.1)\n",
            "Requirement already satisfied: Send2Trash in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter==1.0.0->d2l) (1.8.0)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from notebook->jupyter==1.0.0->d2l) (2.11.3)\n",
            "Requirement already satisfied: pyzmq>=13 in /usr/local/lib/python3.7/dist-packages (from jupyter-client->ipykernel->jupyter==1.0.0->d2l) (22.3.0)\n",
            "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.7/dist-packages (from terminado>=0.8.1->notebook->jupyter==1.0.0->d2l) (0.7.0)\n",
            "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->notebook->jupyter==1.0.0->d2l) (2.0.1)\n",
            "Requirement already satisfied: testpath in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l) (0.5.0)\n",
            "Requirement already satisfied: bleach in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l) (4.1.0)\n",
            "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l) (0.4)\n",
            "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l) (0.8.4)\n",
            "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l) (1.5.0)\n",
            "Requirement already satisfied: defusedxml in /usr/local/lib/python3.7/dist-packages (from nbconvert->jupyter==1.0.0->d2l) (0.7.1)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from bleach->nbconvert->jupyter==1.0.0->d2l) (21.3)\n",
            "Requirement already satisfied: webencodings in /usr/local/lib/python3.7/dist-packages (from bleach->nbconvert->jupyter==1.0.0->d2l) (0.5.1)\n",
            "Requirement already satisfied: qtpy in /usr/local/lib/python3.7/dist-packages (from qtconsole->jupyter==1.0.0->d2l) (2.0.1)\n",
            "Installing collected packages: numpy, requests, pandas, matplotlib, d2l\n",
            "  Attempting uninstall: numpy\n",
            "    Found existing installation: numpy 1.21.5\n",
            "    Uninstalling numpy-1.21.5:\n",
            "      Successfully uninstalled numpy-1.21.5\n",
            "  Attempting uninstall: requests\n",
            "    Found existing installation: requests 2.23.0\n",
            "    Uninstalling requests-2.23.0:\n",
            "      Successfully uninstalled requests-2.23.0\n",
            "  Attempting uninstall: pandas\n",
            "    Found existing installation: pandas 1.3.5\n",
            "    Uninstalling pandas-1.3.5:\n",
            "      Successfully uninstalled pandas-1.3.5\n",
            "  Attempting uninstall: matplotlib\n",
            "    Found existing installation: matplotlib 3.2.2\n",
            "    Uninstalling matplotlib-3.2.2:\n",
            "      Successfully uninstalled matplotlib-3.2.2\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
            "tensorflow 2.8.0 requires numpy>=1.20, but you have numpy 1.18.5 which is incompatible.\n",
            "tables 3.7.0 requires numpy>=1.19.0, but you have numpy 1.18.5 which is incompatible.\n",
            "google-colab 1.0.0 requires requests~=2.23.0, but you have requests 2.25.1 which is incompatible.\n",
            "datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\n",
            "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
            "Successfully installed d2l-0.17.3 matplotlib-3.3.3 numpy-1.18.5 pandas-1.2.2 requests-2.25.1\n"
          ]
        },
        {
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "matplotlib",
                  "mpl_toolkits",
                  "numpy",
                  "pandas"
                ]
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "!pip install d2l\n",
        "import json\n",
        "import multiprocessing\n",
        "import os\n",
        "import torch\n",
        "from torch import nn\n",
        "from d2l import torch as d2l"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OS1VsMq-LgtH"
      },
      "outputs": [],
      "source": [
        "d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',\n",
        "                             '225d66f04cae318b841a13d32af3acc165f253ac')\n",
        "d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',\n",
        "                              'c72329e68a732bef0452e4b96a1c341c8910f81f')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hq9CVgnnLjnc"
      },
      "outputs": [],
      "source": [
        "def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,\n",
        "                          num_heads, num_layers, dropout, max_len, devices):\n",
        "    data_dir = d2l.download_extract(pretrained_model)\n",
        "    # 定义空词表以加载预定义词表\n",
        "    vocab = d2l.Vocab()\n",
        "    vocab.idx_to_token = json.load(open(os.path.join(data_dir,\n",
        "        'vocab.json')))\n",
        "    vocab.token_to_idx = {token: idx for idx, token in enumerate(\n",
        "        vocab.idx_to_token)}\n",
        "    bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],\n",
        "                         ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,\n",
        "                         num_heads=4, num_layers=2, dropout=0.2,\n",
        "                         max_len=max_len, key_size=256, query_size=256,\n",
        "                         value_size=256, hid_in_features=256,\n",
        "                         mlm_in_features=256, nsp_in_features=256)\n",
        "    # 加载预训练BERT参数\n",
        "    bert.load_state_dict(torch.load(os.path.join(data_dir,\n",
        "                                                 'pretrained.params')))\n",
        "    return bert, vocab"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZUO4tFyPLpQM",
        "outputId": "63ae6f8c-c670-40ba-a8bb-474318d0e5e3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...\n"
          ]
        }
      ],
      "source": [
        "devices = d2l.try_all_gpus()\n",
        "bert, vocab = load_pretrained_model(\n",
        "    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,\n",
        "    num_layers=2, dropout=0.1, max_len=512, devices=devices)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LNATOVSXL16a"
      },
      "outputs": [],
      "source": [
        "class SNLIBERTDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, dataset, max_len, vocab=None):\n",
        "        all_premise_hypothesis_tokens = [[\n",
        "            p_tokens, h_tokens] for p_tokens, h_tokens in zip(\n",
        "            *[d2l.tokenize([s.lower() for s in sentences])\n",
        "              for sentences in dataset[:2]])]\n",
        "\n",
        "        self.labels = torch.tensor(dataset[2])\n",
        "        self.vocab = vocab\n",
        "        self.max_len = max_len\n",
        "        (self.all_token_ids, self.all_segments,\n",
        "         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)\n",
        "        print('read ' + str(len(self.all_token_ids)) + ' examples')\n",
        "\n",
        "    def _preprocess(self, all_premise_hypothesis_tokens):\n",
        "        pool = multiprocessing.Pool(4)  # 使用4个进程\n",
        "        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)\n",
        "        all_token_ids = [\n",
        "            token_ids for token_ids, segments, valid_len in out]\n",
        "        all_segments = [segments for token_ids, segments, valid_len in out]\n",
        "        valid_lens = [valid_len for token_ids, segments, valid_len in out]\n",
        "        return (torch.tensor(all_token_ids, dtype=torch.long),\n",
        "                torch.tensor(all_segments, dtype=torch.long),\n",
        "                torch.tensor(valid_lens))\n",
        "\n",
        "    def _mp_worker(self, premise_hypothesis_tokens):\n",
        "        p_tokens, h_tokens = premise_hypothesis_tokens\n",
        "        self._truncate_pair_of_tokens(p_tokens, h_tokens)\n",
        "        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)\n",
        "        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \\\n",
        "                             * (self.max_len - len(tokens))\n",
        "        segments = segments + [0] * (self.max_len - len(segments))\n",
        "        valid_len = len(tokens)\n",
        "        return token_ids, segments, valid_len\n",
        "\n",
        "    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):\n",
        "        # 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置\n",
        "        while len(p_tokens) + len(h_tokens) > self.max_len - 3:\n",
        "            if len(p_tokens) > len(h_tokens):\n",
        "                p_tokens.pop()\n",
        "            else:\n",
        "                h_tokens.pop()\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return (self.all_token_ids[idx], self.all_segments[idx],\n",
        "                self.valid_lens[idx]), self.labels[idx]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.all_token_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CAIBVMGTL9Wu",
        "outputId": "815b8e0d-221e-4b0e-c78e-83b98c09ad04"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...\n",
            "read 549367 examples\n",
            "read 9824 examples\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
            "  cpuset_checked))\n"
          ]
        }
      ],
      "source": [
        "# 如果出现显存不足错误，请减少“batch_size”。在原始的BERT模型中，max_len=512\n",
        "batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()\n",
        "data_dir = d2l.download_extract('SNLI')\n",
        "train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)\n",
        "test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)\n",
        "train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,\n",
        "                                   num_workers=num_workers)\n",
        "test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n",
        "                                  num_workers=num_workers)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4BjQdXANMbNq"
      },
      "outputs": [],
      "source": [
        "class BERTClassifier(nn.Module):\n",
        "    def __init__(self, bert):\n",
        "        super(BERTClassifier, self).__init__()\n",
        "        self.encoder = bert.encoder\n",
        "        self.hidden = bert.hidden#encoder decoder权重是复制来的\n",
        "        self.output = nn.Linear(256, 3)\n",
        "\n",
        "    def forward(self, inputs):\n",
        "        tokens_X, segments_X, valid_lens_x = inputs\n",
        "        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)\n",
        "        return self.output(self.hidden(encoded_X[:, 0, :]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1QDprIloMeYN"
      },
      "outputs": [],
      "source": [
        "net = BERTClassifier(bert)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 463
        },
        "id": "YFzExvxRMiBH",
        "outputId": "2f16247c-9c37-4d4f-b1cd-a3d981551bc1"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
            "  cpuset_checked))\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "ImportError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m    332\u001b[0m                 \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    333\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 334\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mprinter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    335\u001b[0m             \u001b[0;31m# Finally look for special method names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    336\u001b[0m             \u001b[0mmethod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_real_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_method\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(fig)\u001b[0m\n\u001b[1;32m    245\u001b[0m         \u001b[0mjpg_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'jpg'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    246\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'svg'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 247\u001b[0;31m         \u001b[0msvg_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'svg'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    248\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'pdf'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    249\u001b[0m         \u001b[0mpdf_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'pdf'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, **kwargs)\u001b[0m\n\u001b[1;32m    123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    124\u001b[0m     \u001b[0mbytes_io\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBytesIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m     \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbytes_io\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    126\u001b[0m     \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbytes_io\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    127\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfmt\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'svg'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, **kwargs)\u001b[0m\n\u001b[1;32m   2057\u001b[0m             \u001b[0mIf\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m*\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mthen\u001b[0m \u001b[0mdetermine\u001b[0m \u001b[0ma\u001b[0m \u001b[0msuitable\u001b[0m \u001b[0mcanvas\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0;32mfor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2058\u001b[0m             \u001b[0msaving\u001b[0m \u001b[0mto\u001b[0m \u001b[0mformat\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mfmt\u001b[0m\u001b[0;34m*\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;34m-\u001b[0m \u001b[0meither\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mcurrent\u001b[0m \u001b[0mcanvas\u001b[0m \u001b[0;32mclass\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2059\u001b[0;31m             \u001b[0msupports\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mfmt\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mwhatever\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mget_registered_canvas_class\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0mreturns\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2060\u001b[0m             \u001b[0mswitch\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mfigure\u001b[0m \u001b[0mcanvas\u001b[0m \u001b[0mto\u001b[0m \u001b[0mthat\u001b[0m \u001b[0mcanvas\u001b[0m \u001b[0;32mclass\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2061\u001b[0m         \"\"\"\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36m_get_output_canvas\u001b[0;34m(self, fmt)\u001b[0m\n\u001b[1;32m   1991\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmouse_grabber\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1992\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmouse_grabber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1993\u001b[0;31m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1994\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1995\u001b[0m         \u001b[0;34m\"\"\"Render the `.Figure`.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mget_registered_canvas_class\u001b[0;34m(format)\u001b[0m\n\u001b[1;32m    124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    125\u001b[0m     \u001b[0mThe\u001b[0m \u001b[0mfollowing\u001b[0m \u001b[0mmethods\u001b[0m \u001b[0mmust\u001b[0m \u001b[0mbe\u001b[0m \u001b[0mimplemented\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfull\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m     functionality (though just implementing :meth:`draw_path` alone would\n\u001b[0m\u001b[1;32m    127\u001b[0m     give a highly capable backend):\n\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m    125\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m             \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/matplotlib/backends/backend_svg.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmpl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m from matplotlib.backend_bases import (\n\u001b[0m\u001b[1;32m     19\u001b[0m      \u001b[0m_Backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_check_savefig_extra_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureCanvasBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureManagerBase\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m      RendererBase)\n",
            "\u001b[0;31mImportError\u001b[0m: cannot import name '_check_savefig_extra_args' from 'matplotlib.backend_bases' (/usr/local/lib/python3.7/dist-packages/matplotlib/backend_bases.py)"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 252x180 with 1 Axes>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Fine-tuning hyperparameters: Adam step size and number of training epochs.\n",
        "lr = 1e-4\n",
        "num_epochs = 5\n",
        "\n",
        "# reduction='none' asks the criterion for unreduced losses; presumably\n",
        "# d2l.train_ch13 aggregates them itself -- TODO confirm for the installed d2l.\n",
        "loss = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "# Adam over every parameter exposed by the network.\n",
        "trainer = torch.optim.Adam(net.parameters(), lr=lr)\n",
        "\n",
        "d2l.train_ch13(\n",
        "    net, train_iter, test_iter, loss, trainer, num_epochs, devices)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "BERT微调代码.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}